#############################################################################################
# NKPH summary methods
#############################################################################################


using Plots: plot, plot!
using LaTeXStrings
using Latexify


# Directory where generated output files are written (e.g. the moment workbook written by
# `save_moments`). It is a relative path, so it resolves against the current working
# directory. File paths are built by plain string concatenation, so keep the trailing "/".
OUTPUT_DIR = "../output/"


function _init_nkphopt(nkphopt::NKPHEstimationOptimizer;
        display=false, draw_seed=1234, draw_init_method=0, draw_init_std=0.0
    )
    # Initialize `nkphopt` in place and return it. Steps, in order:
    #   1. empirical moments, model, targets and the NLopt setup;
    #   2. draw initial parameters with draw_init_params(draw_seed, nkphopt,
    #      draw_init_method, draw_init_std);
    #   3. apply them through set_continuation_targets!.
    # With `display=true`, progress and diagnostics (max |root|, Upsilon/M eigenvalue
    # bounds, loss value) are printed to stdout.
    for init_step! in (initialize_empirical_moments!, initialize_model!,
            initialize_targets!, setup_nlopt!)
        init_step!(nkphopt)
    end

    display && println("setting up initial targets")
    p0_arr = draw_init_params(draw_seed, nkphopt, draw_init_method, draw_init_std)
    set_continuation_targets!(p0_arr, nkphopt)

    display || return nkphopt

    # diagnostics of the solved model at the initial parameter draw
    tgts = nkphopt.nkphtgts
    println("max root: ", maximum(abs.(nkphopt.nkphsolvr.root_vec)))
    println("min/max Upsilon eigenvals: ", tgts.Upsilon_min_pos_eigval[1],
        " ", tgts.Upsilon_max_neg_eigval[1])
    println("min M eigenval: ", tgts.M_min_eigval[1])
    println("loss value: ", tgts.loss_val[1])
    return nkphopt
end


function load_nkph_model_estimates(model_fname, is_high_crisis; all_maturity=true, display=false)
    # Build and initialize an NKPHEstimationOptimizer for the model file
    # MODEL_DIR * model_fname * ".xlsx", using the default moment files under MOMENT_DIR.
    # `is_high_crisis` selects the regression files and rates sample window:
    #   true  -> *_crisis files, 2008m1 to 2012m1
    #   false -> *_non_crisis files, 1986m6 to 2007m1
    # With `all_maturity=true`, the maturity grids are replaced:
    #   tau_arr_long and tau_arr_til -> 1..30
    #   tau_arr_reg                  -> 1..20 and 23..30
    # Returns the result of `_init_nkphopt` (the initialized optimizer).
    regs_fname, alt_regs_fname, rates_start_date, rates_end_date = if is_high_crisis
        ("habitat_regressions_crisis", "alt_habitat_regressions_crisis",
            "2008m1", "2012m1")
    else
        ("habitat_regressions_non_crisis", "alt_habitat_regressions_non_crisis",
            "1986m6", "2007m1")
    end
    moment_csv(stem) = MOMENT_DIR * stem * ".csv"

    nkphopt = NKPHEstimationOptimizer(
        MODEL_DIR * model_fname * ".xlsx",
        moment_csv("yield_macro_data"),
        moment_csv(regs_fname),
        moment_csv(alt_regs_fname),
        rates_start_date, rates_end_date
    )

    if all_maturity
        # separate vectors so the long and tilde grids do not alias each other
        nkphopt.tau_arr_long = collect(1.0:30.0)
        nkphopt.tau_arr_til = collect(1.0:30.0)
        # maturities 21 and 22 are intentionally absent from the regression grid
        nkphopt.tau_arr_reg = vcat(1.0:20.0, 23.0:30.0)
    end

    return _init_nkphopt(nkphopt; display=display)
end



# Helper functions that compute yield, price, and holdings responses to shocks.
# They are used for reporting only, not in the solution/estimation algorithm.
function _get_bond_affine_coefs(tau::Float64, nkphsolvr::NKPHSolver)
    # Affine bond coefficients for a single maturity `tau`. Returns (A, A_til) with
    #   A     = calc_exp_matrix_tau_integral(tau, ...) * ei_vec
    #   A_til = A - calc_exp_matrix_tau(tau, ...) * ed_vec
    # Both helpers are evaluated from the eigendecomposition of nkphsolvr.M.
    M_eig = eigen(nkphsolvr.M)
    lambdas, V = M_eig.values, M_eig.vectors
    coef = calc_exp_matrix_tau_integral(tau, lambdas, V) * nkphsolvr.ei_vec
    coef_til = coef - calc_exp_matrix_tau(tau, lambdas, V) * nkphsolvr.ed_vec
    return coef, coef_til
end
function _get_bond_affine_coefs(tau_arr::Vector{Float64}, nkphsolvr::NKPHSolver)
    # Affine bond coefficients on a grid of maturities. Returns two J x length(tau_arr)
    # matrices (A, A_til). Column k holds, for tau = tau_arr[k]:
    #   A[:, k]     = calc_exp_matrix_tau_integral(tau, ...) * ei_vec
    #   A_til[:, k] = A[:, k] - calc_exp_matrix_tau(tau, ...) * ed_vec
    # The eigendecomposition of nkphsolvr.M is computed once and reused for every tau.
    M_eig = eigen(nkphsolvr.M)
    lambdas, V = M_eig.values, M_eig.vectors
    ei_vec, ed_vec = nkphsolvr.ei_vec, nkphsolvr.ed_vec

    n_mat = length(tau_arr)
    A_tau_arr = zeros(nkphsolvr.J, n_mat)
    A_tau_arr_til = zeros(nkphsolvr.J, n_mat)
    for (col, tau) in enumerate(tau_arr)
        a_col = calc_exp_matrix_tau_integral(tau, lambdas, V) * ei_vec
        A_tau_arr[:, col] = a_col
        A_tau_arr_til[:, col] = a_col - calc_exp_matrix_tau(tau, lambdas, V) * ed_vec
    end
    return A_tau_arr, A_tau_arr_til
end
function calc_yield_response(y::Vector{Float64}, tau_arr::Vector{Float64},
        nkphsolvr::NKPHSolver)
    # Riskless and risky yield-curve responses to the state vector `y`, evaluated at each
    # maturity in `tau_arr`. Returns (yc_resp, yc_resp_til):
    #   yc_resp     = (A' * y) ./ tau_arr
    #   yc_resp_til = ((A_til .+ ed_vec)' * y) ./ tau_arr
    # yc_resp_til is the "effective" risky yield: ed_vec is added to every column of
    # A_til before projecting on y.
    A, A_til = _get_bond_affine_coefs(tau_arr, nkphsolvr)
    A_til_eff = A_til .+ nkphsolvr.ed_vec
    yc_resp = (A' * y) ./ tau_arr
    yc_resp_til = (A_til_eff' * y) ./ tau_arr
    return yc_resp, yc_resp_til
end
function calc_forward_rate_response(y::Vector{Float64}, tau_arr::Vector{Float64},
        nkphsolvr::NKPHSolver)
    # Riskless and risky forward-rate responses to the state vector `y` at maturities
    # `tau_arr`. For each coefficient matrix C in (A, A_til) from _get_bond_affine_coefs,
    # the response is (-M * C .+ ei_vec)' * y, i.e. one value per maturity.
    M, ei_vec = nkphsolvr.M, nkphsolvr.ei_vec
    fwd_resp(C) = (-M * C .+ ei_vec)' * y
    A, A_til = _get_bond_affine_coefs(tau_arr, nkphsolvr)
    f_resp = fwd_resp(A)
    f_resp_til = fwd_resp(A_til)
    return f_resp, f_resp_til
end

function calc_bond_expected_return_response(y::Vector{Float64}, tau_arr::Vector{Float64},
        nkphsolvr::NKPHSolver)
    # Riskless and risky expected-return responses to the state vector `y` at maturities
    # `tau_arr`. For each coefficient matrix C in (A, A_til) the response is
    # ((Gamma' - M) * C .+ ei_vec)' * y, with Gamma = nkphsolvr.sdyn.Gamma.
    ei_vec = nkphsolvr.ei_vec
    drift_mat = nkphsolvr.sdyn.Gamma' - nkphsolvr.M
    exp_ret_resp(C) = (drift_mat * C .+ ei_vec)' * y
    A, A_til = _get_bond_affine_coefs(tau_arr, nkphsolvr)
    mu_resp = exp_ret_resp(A)
    mu_resp_til = exp_ret_resp(A_til)
    return mu_resp, mu_resp_til
end


################################################################
# habitat portfolio response
function _get_alpha_tau_func(tau_arr::Vector{Float64}, nkphsolvr::NKPHSolver)
    # Maturity profile of the habitat price-elasticity coefficient alpha for the H and
    # F investor groups. Each is calc_alpha_tau evaluated at every tau in `tau_arr`:
    #   H uses (alpha0[1], alpha1[1]); F uses (alpha0_til[1], alpha1_til[1]).
    # Returns (alpha_arr, alpha_arr_til).
    alpha_profile(a0, a1) = [calc_alpha_tau(tau, a0, a1) for tau in tau_arr]
    alpha_arr = alpha_profile(nkphsolvr.alpha0[1], nkphsolvr.alpha1[1])
    alpha_arr_til = alpha_profile(nkphsolvr.alpha0_til[1], nkphsolvr.alpha1_til[1])
    return alpha_arr, alpha_arr_til
end

function _get_Theta_tau_func(tau_arr::Vector{Float64}, nkphsolvr::NKPHSolver)
    # Maturity profile of the habitat factor loadings Theta for the H and F groups.
    # Returns two J x length(tau_arr) matrices. Column k is calc_Theta_tau at tau_arr[k]
    # with (Theta0, Theta1) for H and (Theta0_til, Theta1_til) for F.
    th0, th1 = nkphsolvr.Theta0, nkphsolvr.Theta1
    th0_til, th1_til = nkphsolvr.Theta0_til, nkphsolvr.Theta1_til

    n_mat = length(tau_arr)
    Theta_arr = zeros(nkphsolvr.J, n_mat)
    Theta_arr_til = zeros(nkphsolvr.J, n_mat)
    for (col, tau) in enumerate(tau_arr)
        Theta_arr[:, col] = calc_Theta_tau(tau, th0, th1)
        Theta_arr_til[:, col] = calc_Theta_tau(tau, th0_til, th1_til)
    end
    return Theta_arr, Theta_arr_til
end

function calc_bond_elastic_portfolio_response(y::Vector{Float64}, tau_arr::Vector{Float64},
        nkphsolvr::NKPHSolver)
    # Price-elastic (endogenous) part of the habitat investors' holdings response to the
    # state `y`, per maturity in `tau_arr`: alpha(tau) .* (A(tau)' * y) for H, and the
    # tilde analogue alpha_til(tau) .* (A_til(tau)' * y) for F.
    A, A_til = _get_bond_affine_coefs(tau_arr, nkphsolvr)
    alpha, alpha_til = _get_alpha_tau_func(tau_arr, nkphsolvr)
    Z_endog = alpha .* (A' * y)
    Z_endog_til = alpha_til .* (A_til' * y)
    return Z_endog, Z_endog_til
end

function calc_bond_factor_portfolio_response(y::Vector{Float64}, tau_arr::Vector{Float64},
        nkphsolvr::NKPHSolver)
    # Exogenous factor-driven part of the habitat investors' holdings response to `y`,
    # per maturity in `tau_arr`: -Theta' * y for H and -Theta_til' * y for F.
    Theta, Theta_til = _get_Theta_tau_func(tau_arr, nkphsolvr)
    Z_exog = -Theta' * y
    Z_exog_til = -Theta_til' * y
    return Z_exog, Z_exog_til
end

function calc_bond_portfolio_response(y::Vector{Float64}, tau_arr::Vector{Float64},
        nkphsolvr::NKPHSolver)
    # Total habitat holdings response to `y` per maturity, returned as (H, F). Each is
    # the sum of the price-elastic part (calc_bond_elastic_portfolio_response) and the
    # factor part (calc_bond_factor_portfolio_response).
    elastic_H, elastic_F = calc_bond_elastic_portfolio_response(y, tau_arr, nkphsolvr)
    factor_H, factor_F = calc_bond_factor_portfolio_response(y, tau_arr, nkphsolvr)
    total_H = elastic_H + factor_H
    total_F = elastic_F + factor_F
    return total_H, total_F
end
################################################################

# Impulse response function (IRF) methods
function calc_irf(
        y0::Vector{Float64},
        t_arr::Vector{Float64},
        nkphsolvr::NKPHSolver
    )
    # Impulse responses to an initial state shock `y0`, evaluated at each horizon in
    # `t_arr`. Returns (y_mat, x_mat):
    #   y_mat[:, k] = exp(-t_arr[k] * Gamma) * y0   (state, J rows)
    #   x_mat[:, k] = Omega * y_mat[:, k]           (jumps, N - J rows)
    # Gamma and Omega are taken from nkphsolvr.sdyn; exp is the matrix exponential.
    Gamma, Omega = nkphsolvr.sdyn.Gamma, nkphsolvr.sdyn.Omega
    n_state = nkphsolvr.J
    n_jump = nkphsolvr.N - n_state
    n_t = length(t_arr)

    y_mat = zeros(n_state, n_t)
    x_mat = zeros(n_jump, n_t)
    for (k, t) in enumerate(t_arr)
        y_t = exp(-t .* Gamma) * y0
        y_mat[:, k] = y_t
        x_mat[:, k] = Omega * y_t
    end
    return y_mat, x_mat
end

function calc_longrun_effect(
        y0::Vector{Float64},
        nkphsolvr::NKPHSolver
    )
    # Long-run effect of an initial state shock `y0`. Returns (y_oo, x_oo), where y_oo
    # solves Gamma * y_oo = y0 and x_oo = Omega * y_oo, with Gamma and Omega taken from
    # nkphsolvr.sdyn.
    # When every eigenvalue of Gamma has positive real part, y_oo equals the integral
    # over t >= 0 of the state IRF exp(-t * Gamma) * y0.
    sdyn = nkphsolvr.sdyn
    y_oo = sdyn.Gamma \ y0
    return y_oo, sdyn.Omega * y_oo
end

function calc_longrun_volatilty(nkphmmts::NKPHMoments)
    # Long-run standard deviations of the state (y) and jump (x = Omega * y) variables.
    # Returns (sd_y, sd_x):
    #   sd_y = square roots of diag(Sig_oo)
    #   sd_x = square roots of diag(Omega * Sig_oo * Omega')
    # Sig_oo is nkphmmts.Sig_oo (presumably the stationary state covariance; confirm).
    # The misspelled function name is kept for compatibility with existing callers.
    Sig_y = nkphmmts.Sig_oo
    Omega = nkphmmts.nkphsolvr.sdyn.Omega
    Sig_x = Omega * Sig_y * Omega'
    return sqrt.(diag(Sig_y)), sqrt.(diag(Sig_x))
end

function calc_yield_response_irf(
        y0::Vector{Float64},
        t_arr::Vector{Float64},
        tau_val::Float64,
        nkphsolvr::NKPHSolver
    )
    # Time path of the riskless and risky yields at one maturity `tau_val` after an
    # initial state shock `y0`, evaluated at each horizon in `t_arr`.
    # Returns (yld_resp_irf, yld_resp_til_irf), each of length(t_arr):
    #   yld_resp_irf[i]     = (A' * y_i) / tau_val
    #   yld_resp_til_irf[i] = ((A_til .+ ed_vec)' * y_i) / tau_val   ("effective" risky)
    # Here y_i is the state IRF from calc_irf at t_arr[i], and (A, A_til) are the affine
    # coefficients at tau_val.
    y_mat, _ = calc_irf(y0, t_arr, nkphsolvr)
    A_tau, A_tau_til = _get_bond_affine_coefs(tau_val, nkphsolvr)
    # The effective risky coefficient does not depend on t, so build it once instead of
    # once per horizon.
    A_tau_til_eff = A_tau_til .+ nkphsolvr.ed_vec

    yld_resp_irf = zeros(length(t_arr))
    yld_resp_til_irf = zeros(length(t_arr))
    for i in eachindex(yld_resp_irf)
        y = view(y_mat, :, i)  # no per-horizon copy of the state column
        yld_resp_irf[i] = (A_tau' * y) / tau_val
        yld_resp_til_irf[i] = (A_tau_til_eff' * y) / tau_val
    end
    return yld_resp_irf, yld_resp_til_irf
end


function _compare_model_moment(idx_group::Int, nkphopt::NKPHEstimationOptimizer)
    # Model-vs-data table for one moment group (row `idx_group` of df_target_covs).
    # There is one row per moment index in nkphopt.mmtloss_group_lookup[idx_group].
    # Columns: the group's var_i, var_j, cov_type and coeff_type, the moment's maturity
    # (mmt.mmtvarL.tau), and mmt.bhat[1] and mmt.b[1].
    # NOTE(review): bhat/b are presumably data vs model values; confirm in NKPH targets.
    # Columns are left untyped (Matrix{Any}) because rows mix strings and numbers.
    covs = nkphopt.df_target_covs
    var_i, var_j, cov_type, coeff_type =
        (covs[idx_group, col] for col in (:var_i, :var_j, :cov_type, :coeff_type))

    idx_moments = nkphopt.mmtloss_group_lookup[idx_group]
    moment_mat = Matrix{Any}(undef, length(idx_moments), 7)
    for (row_pos, mmt_pos) in enumerate(idx_moments)
        mmt = nkphopt.nkphtgts.mmtloss_arr[mmt_pos]
        moment_mat[row_pos, :] = [var_i, var_j, cov_type, coeff_type,
            mmt.mmtvarL.tau, mmt.bhat[1], mmt.b[1]]
    end

    col_names = [:var_i, :var_j, :cov_type, :coeff_type, :tau, :bhat, :b]
    return DataFrame(moment_mat, col_names)
end


function _get_moment_group_idx(var_i::String, var_j::String, nkphopt::NKPHEstimationOptimizer)
    # Return the first row index of nkphopt.df_target_covs whose (:var_i, :var_j) entries
    # equal (var_i, var_j), or 0 if no row matches.
    # isequal (not ==) is used so that a row with a `missing` var_i/var_j simply does not
    # match. With ==, the comparison yields `missing`, and `&&` then throws a TypeError.
    covs = nkphopt.df_target_covs
    idx = findfirst(
        r -> isequal(covs[r, :var_i], var_i) && isequal(covs[r, :var_j], var_j),
        1:size(covs, 1)
    )
    return something(idx, 0)
end
function _get_moment_group_idx(coeff_type::String, nkphopt::NKPHEstimationOptimizer)
    # Return the first row index of nkphopt.df_target_covs whose :coeff_type equals
    # `coeff_type`, or 0 if none does. Rows with a missing coeff_type never match.
    match_idx = findfirst(isequal(coeff_type), nkphopt.df_target_covs[!, :coeff_type])
    return something(match_idx, 0)
end

function _get_scalar_moment_indices(nkphopt::NKPHEstimationOptimizer)
    # Row indices of nkphopt.df_target_covs (ascending) whose :moment_type_i is "scalar".
    moment_types = nkphopt.df_target_covs[:, :moment_type_i]
    return findall(==("scalar"), moment_types)
end
function _get_nonscalar_moment_indices(nkphopt::NKPHEstimationOptimizer)
    # Row indices of nkphopt.df_target_covs (ascending) whose :moment_type_i is anything
    # other than "scalar". This is the complement of _get_scalar_moment_indices.
    moment_types = nkphopt.df_target_covs[:, :moment_type_i]
    return findall(!=("scalar"), moment_types)
end


function _get_scalar_moments(nkphopt::NKPHEstimationOptimizer)
    # Table of all scalar moment groups, one row per group. Columns: var_i, var_j,
    # cov_type, coeff_type and wgts (from df_target_covs), plus bhat[1] and b[1] of the
    # group's single moment.
    # Throws ErrorException if a scalar group maps to more than one moment index.
    covs = nkphopt.df_target_covs
    group_idxs = _get_scalar_moment_indices(nkphopt)
    moment_mat = Matrix{Any}(undef, length(group_idxs), 7)
    for (row, idx_group) in enumerate(group_idxs)
        lookup = nkphopt.mmtloss_group_lookup[idx_group]
        length(lookup) > 1 && throw(ErrorException("scalar moment has multiple idxs"))
        mmt = nkphopt.nkphtgts.mmtloss_arr[lookup[1]]
        moment_mat[row, :] = [covs[idx_group, :var_i], covs[idx_group, :var_j],
            covs[idx_group, :cov_type], covs[idx_group, :coeff_type],
            covs[idx_group, :wgts], mmt.bhat[1], mmt.b[1]]
    end
    col_names = [:var_i, :var_j, :cov_type, :coeff_type, :wgts, :bhat, :b]
    return DataFrame(moment_mat, col_names)
end

function _get_nonscalar_moments(nkphopt::NKPHEstimationOptimizer)
    # For every non-scalar moment group, build its model-vs-data table
    # (_compare_model_moment) and a title. The title is var_i when var_i == var_j,
    # otherwise "<var_i>_<var_j>_<idx_condcoeff>".
    # Returns (tables, titles), both in the order of _get_nonscalar_moment_indices.
    covs = nkphopt.df_target_covs
    group_idxs = _get_nonscalar_moment_indices(nkphopt)
    n_groups = length(group_idxs)

    df_mmts_arr = Vector{DataFrame}(undef, n_groups)
    title_str_arr = Vector{String}(undef, n_groups)
    for (k, idx_group) in enumerate(group_idxs)
        var_i, var_j = covs[idx_group, :var_i], covs[idx_group, :var_j]
        cond_idx = covs[idx_group, :idx_condcoeff]
        title_str_arr[k] = var_i == var_j ? var_i :
            var_i * "_" * var_j * "_" * string(cond_idx)
        df_mmts_arr[k] = _compare_model_moment(idx_group, nkphopt)
    end
    return df_mmts_arr, title_str_arr
end

function save_moments(xls_fname::String, nkphopt::NKPHEstimationOptimizer)
    # Save the model-vs-data moment tables to OUTPUT_DIR * xls_fname * ".xlsx", one
    # worksheet per table, in this order:
    #   - "mmts_scalar", only if there is at least one scalar moment;
    #   - one sheet per non-scalar moment group, titled as by _get_nonscalar_moments.
    # An existing file at that path is replaced. Nothing is written if there are no tables.
    # NOTE(review): sheet titles must be unique, and Excel limits sheet names to 31
    # characters; long var_i/var_j titles may need shortening.
    #
    # All sheets go into a single workbook. XLSX.writetable(path, ...) creates a fresh
    # single-sheet file on every call: with overwrite=true each sheet replaced the
    # previous one, and without overwrite the call fails if the file exists.
    xls_fpath = OUTPUT_DIR * xls_fname * ".xlsx"

    sheets = Pair{String, DataFrame}[]
    df_mmts_scalar = _get_scalar_moments(nkphopt)
    if size(df_mmts_scalar, 1) > 0
        push!(sheets, "mmts_scalar" => df_mmts_scalar)
    end
    df_mmts_arr, title_str_arr = _get_nonscalar_moments(nkphopt)
    for (df_mmts_nonscalar, title_str) in zip(df_mmts_arr, title_str_arr)
        push!(sheets, title_str => df_mmts_nonscalar)
    end
    isempty(sheets) && return nothing

    XLSX.openxlsx(xls_fpath, mode="w") do xf
        for (k, (sheetname, df)) in enumerate(sheets)
            # a new workbook starts with one default sheet: reuse it for the first table
            if k == 1
                sheet = xf[1]
                XLSX.rename!(sheet, sheetname)
            else
                sheet = XLSX.addsheet!(xf, sheetname)
            end
            XLSX.writetable!(sheet, _df2xls(df)...)
        end
    end

    return nothing
end


function _get_model_params(nkphopt::NKPHEstimationOptimizer)
    # Stack estimated and fixed parameters into one table (first two columns of each).
    #   Estimated rows: a copy of df_params_constraints[:, 1:2] with the "init" column
    #     renamed to "val" and filled with the first N_params entries of
    #     nkphsolvr.x_dict.vals.
    #   Fixed rows: df_params_fixed[:, 1:2].
    # NOTE(review): vcat requires matching column names, so df_params_fixed presumably
    # also uses "val" for its second column; confirm.
    solvr = nkphopt.nkphsolvr
    n_est = solvr.N_params

    df_est = copy(nkphopt.df_params_constraints[:, 1:2])
    rename!(df_est, "init" => "val")
    df_est[:, :val] = solvr.x_dict.vals[1:n_est]

    df_fixed = nkphopt.df_params_fixed[:, 1:2]
    return vcat(df_est, df_fixed)
end
